import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SegmentAllocator;
import java.lang.foreign.ValueLayout;
import java.util.Arrays;

/**
 * A per-thread, stack-style (LIFO) bump allocator of native memory.
 *
 * <p>Objects returned by {@link #create(int)} and {@link #push()} must be released with
 * try-with-resources. A {@code MemStack2} and the {@link MemorySegment}s it allocates may only be
 * passed along the current thread's call stack (including as method arguments); they must never
 * be stored on the heap (in object or class fields). This keeps their lifetime within the owning
 * {@code MemStack2} frame and rules out concurrent access.
 *
 * <p>Segments returned by {@link #allocate(long, long)} are slices of a single backing segment
 * owned by a confined {@link Arena}. Popping a frame only rewinds the allocation offset; it does
 * not invalidate slices handed out inside that frame, so such a slice used after its frame is
 * popped aliases memory that later allocations will hand out again.
 */
public class MemStack2 implements SegmentAllocator, AutoCloseable {
	/** Size in bytes of the backing segment created implicitly by {@link #push()}. */
	private static final int DEFAULT_STACK_SIZE = 64 * 1024;

	/** The instance owned by the current thread, or {@code null} when none is open. */
	private static final ThreadLocal<MemStack2> tl = new ThreadLocal<>();
	/** Owns the backing memory; confined to the thread that created this instance. */
	private final Arena arena = Arena.ofConfined();
	/** The backing segment that all allocations are sliced from. */
	private final MemorySegment memSeg;
	/** Number of bytes of {@link #memSeg} already handed out (offset of the next free byte). */
	private int offset;
	/** Saved {@link #offset} values, one per nested {@link #push()}. */
	private int[] frames = new int[4];
	/** Number of saved entries in {@link #frames}, i.e. the current nesting depth. */
	private int frameIdx;

	private MemStack2() {
		this(DEFAULT_STACK_SIZE);
	}

	private MemStack2(int stackSize) {
		MemorySegment seg;
		try {
			seg = arena.allocate(stackSize);
		} catch (RuntimeException | Error e) {
			// The arena is already open at this point; close it so a failed allocation
			// (e.g. a negative size) does not leak it.
			arena.close();
			throw e;
		}
		memSeg = seg;
	}

	/**
	 * Returns the number of bytes currently allocated, including alignment padding.
	 *
	 * @return the current allocation offset into the backing segment
	 */
	public int getAllocSize() {
		return offset;
	}

	/**
	 * Returns the capacity of the backing segment.
	 *
	 * @return the backing segment's size in bytes
	 */
	public int getMaxSize() {
		return (int)memSeg.byteSize();
	}

	/**
	 * Creates the current thread's stack with a backing segment of {@code stackSize} bytes.
	 * The returned object must be closed (try-with-resources), which releases the memory.
	 *
	 * @param stackSize capacity in bytes of the backing segment
	 * @return the newly created stack, now registered as the current thread's stack
	 * @throws IllegalStateException if the current thread already has an open stack
	 * @throws IllegalArgumentException if {@code stackSize} is negative
	 */
	public static MemStack2 create(int stackSize) {
		var ms = tl.get();
		if (ms != null)
			throw new IllegalStateException(MemStack2.class.getName() + " already created");
		tl.set(ms = new MemStack2(stackSize));
		return ms;
	}

	/**
	 * Opens a new allocation frame on the current thread's stack.
	 *
	 * <p>If the thread has no stack yet, one with a {@value #DEFAULT_STACK_SIZE}-byte backing
	 * segment is created and becomes the outermost frame. Otherwise the current offset is saved
	 * so that closing the returned object rewinds all allocations made inside this frame.
	 *
	 * @return the current thread's stack; must be closed with try-with-resources
	 */
	public static MemStack2 push() {
		var ms = tl.get();
		if (ms == null)
			tl.set(ms = new MemStack2());
		else {
			int frameIdx = ms.frameIdx;
			var frames = ms.frames;
			int frameCount = frames.length;
			if (frameIdx == frameCount)
				ms.frames = frames = Arrays.copyOf(frames, frameCount << 1); // grow by doubling
			frames[frameIdx++] = ms.offset;
			ms.frameIdx = frameIdx;
		}
		return ms;
	}

	/**
	 * Closes the innermost frame.
	 *
	 * <p>For a nested frame, the allocation offset is rewound to its value at the matching
	 * {@link #push()}. For the outermost frame, the arena is closed (releasing the backing memory
	 * and invalidating every segment obtained from it) and the thread-local registration is
	 * removed.
	 */
	public void pop() {
		int frameIdx = this.frameIdx;
		if (frameIdx != 0) {
			offset = frames[--frameIdx];
			this.frameIdx = frameIdx;
		} else {
			arena.close();
			tl.remove();
		}
	}

	/**
	 * Bump-allocates {@code byteSize} bytes whose absolute address is a multiple of
	 * {@code byteAlignment}.
	 *
	 * @param byteSize number of bytes to allocate; must be non-negative
	 * @param byteAlignment required alignment; must be a positive power of two
	 * @return a slice of the backing segment
	 * @throws IllegalArgumentException if {@code byteSize} is negative or {@code byteAlignment}
	 *         is not a positive power of two
	 * @throws StackOverflowError if the request (including alignment padding) does not fit in
	 *         the remaining space; the allocation state is left unchanged
	 */
	@Override
	public MemorySegment allocate(long byteSize, long byteAlignment) {
		if (byteSize < 0)
			throw new IllegalArgumentException("negative byteSize: " + byteSize);
		if (byteAlignment <= 0 || (byteAlignment & (byteAlignment - 1)) != 0)
			throw new IllegalArgumentException("invalid byteAlignment: " + byteAlignment);
		var seg = memSeg;
		long size = seg.byteSize();
		long off = offset;
		long mask = byteAlignment - 1;
		// Align the absolute address, not just the offset, since the base may be unaligned.
		long lowBits = (seg.address() + off) & mask;
		if (lowBits != 0)
			off += byteAlignment - lowBits;
		// Bound check by subtraction: off + byteSize can overflow for huge byteSize, which would
		// slip past an "end > size" test and corrupt offset before asSlice throws.
		if (byteSize > size - off)
			throw new StackOverflowError(getClass().getName() + " overflow: " + size);
		offset = (int)(off + byteSize); // fits: off + byteSize <= size <= Integer.MAX_VALUE
		return seg.asSlice(off, byteSize);
	}

	/** Closes the innermost frame; equivalent to {@link #pop()}. */
	@Override
	public void close() {
		pop();
	}

	/**
	 * Small demo: allocates a float and a long, then prints the backing segment's base address,
	 * the two allocated addresses and the number of bytes used.
	 */
	public static void main(String[] args) {
		try (var stack = MemStack2.push()) {
			var v1 = stack.allocateFrom(ValueLayout.JAVA_FLOAT, 12345);
			var v2 = stack.allocateFrom(ValueLayout.JAVA_LONG, 67890);
			System.out.println(stack.memSeg.address());
			System.out.println(v1.address());
			System.out.println(v2.address());
			System.out.println(stack.offset);
			System.out.println(stack.getAllocSize());
		}
	}
}
